package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.DCT
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.StringIndexerModel

/**
 * Companion factory for the [[StringIndexerTest]] demo job.
 */
object StringIndexerTest extends BaseTest {

    /**
     * Creates a new demo instance.
     *
     * @return a fresh [[StringIndexerTest]], exposed as a [[BaseSpark]]
     */
    def apply(): BaseSpark = {
        val job = new StringIndexerTest
        job
    }

}

/**
 * Demonstrates Spark ML's [[StringIndexer]]. It fits an indexer on a small
 * in-memory `category` column, prints the learned labels and shows the
 * transformed rows.
 */
class StringIndexerTest extends BaseSpark {

    /**
     * Builds the sample input. The DataFrame has six rows, with an integer `id`
     * and a string `category` taking the values "a", "b" or "c".
     *
     * @param sparkSession session used to create the DataFrame; defaults to `this.getSparkSession()`
     * @return a DataFrame with columns `id` and `category`
     */
    override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
        sparkSession.createDataFrame(Seq(
            (0, "a"), (1, "b"),
            (2, "c"), (3, "a"),
            (4, "a"), (5, "c")
        )).toDF("id", "category")
    }

    /**
     * Fits a [[StringIndexer]] on the `category` column and prints the learned
     * labels. It then shows the input with an added `categoryIndex` column.
     *
     * @param dataFrame input; must contain a string column named `category`
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Configure the estimator: maps the string column `category` to `categoryIndex`.
        val indexer = new StringIndexer()
            .setInputCol("category")
            .setOutputCol("categoryIndex")
        // Fit the model; the fitted model holds the label-to-index mapping.
        // Declared as `val`: the reference is never reassigned.
        val model: StringIndexerModel = indexer.fit(dataFrame)
        // Print the learned labels; per Spark's docs, a label's position in this array is its index.
        println(model.labels.mkString(","))
        // Apply the mapping and display the result.
        val indexed = model.transform(dataFrame)
        indexed.show()
    }

}
